-
Notifications
You must be signed in to change notification settings - Fork 26.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix llama model sdpa attention forward function masking bug when output_attentions=True #30652
Fix llama model sdpa attention forward function masking bug when output_attentions=True #30652
Conversation
A minimal example of this erroneous behavior can be reproduced via: from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
device_map='cuda',
torch_dtype=torch.bfloat16
)
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer(["Today is the day I went to the store and ..."],
return_tensors="pt").to('cuda')
expanded_batch_size = 1
outputs = model.generate(
input_ids = inputs['input_ids'].expand(expanded_batch_size, -1),
attention_mask = inputs['attention_mask'].expand(expanded_batch_size, -1),
do_sample=False,
max_new_tokens=5,
return_dict_in_generate=True,
)
input_length = inputs.input_ids.shape[1]
sequences= outputs.sequences
for sequence in sequences:
decoded_sequence = tokenizer.decode(sequence)
print(decoded_sequence)
# separator
print('-'*20)
outputs = model.generate(
input_ids = inputs['input_ids'].expand(expanded_batch_size, -1),
attention_mask = inputs['attention_mask'].expand(expanded_batch_size, -1),
do_sample=False,
max_new_tokens=5,
return_dict_in_generate=True,
output_attentions=True, # ?!
)
input_length = inputs.input_ids.shape[1]
sequences= outputs.sequences
# garbage generated outputs since no masking is applied
for sequence in sequences:
decoded_sequence = tokenizer.decode(sequence)
print(decoded_sequence) |
…he same sdpa masking logic from llama)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch.
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
line 1127 needs to be ignored as well.- we need to add your small example script as a test! 🤗
@ArthurZucker Thanks for reviewing my pull request and all your work in maintaining this awesome repo! :) Regarding your comments:
p.s. There seem to be some CircleCI tests failing on the main branch... which are now failing after I merged. |
For 2. the test is already implemented, but I don't think it tests Potentially adding |
Feel free to rebase it might be fixed on main / be flaky |
Just did :) |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@ArthurZucker Let me know if you think this fix is ready for merging, or if you'd like to add the tests to the same PR! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be nice to just add the test in this PR 😉
Alright - I made the addition of output_attentions=True to the sdpa equivalence test, as you suggested ;) (Black code re-formatting seems to have messed up the diff, but the changes are minimal...) @ArthurZucker - Let me know if there are any outstanding issues or if there is something else missing before merging ^^ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, let's make sure you rebase as Gemma was updated a bit and commit with [run-slow] so that slow tests are run!
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Merging once the CIs are all green!) |
@ArthurZucker thanks for your suggestions! I also propagated the same changes to the new jetmoe model. All default checks are now passing ^^ |
THanks for the fix |
…ut_attentions=True (#30652) * Fix llama model forward function with attention=True, same-length encoded sequence. * Fix style * propagate fix to modeling_cohere, gemma, dbrx, and olmo (which copy the same sdpa masking logic from llama) * Fix style * ignore unnecessary sdpa mask converter when output_attentions=True * add tests checking sdpa and eager outputs match when output_attentions=True * Split if statements in two lines Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix formatting * Add fix to new jetmoe model * Add missing output_attentions argument to jetmoe mask creation --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@Aladoro thank you for detecting the issue and making |
…ut_attentions=True (huggingface#30652) * Fix llama model forward function with attention=True, same-length encoded sequence. * Fix style * propagate fix to modeling_cohere, gemma, dbrx, and olmo (which copy the same sdpa masking logic from llama) * Fix style * ignore unnecessary sdpa mask converter when output_attentions=True * add tests checking sdpa and eager outputs match when output_attentions=True * Split if statements in two lines Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix formatting * Add fix to new jetmoe model * Add missing output_attentions argument to jetmoe mask creation --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
What does this PR do?
Very simple fix to a nasty issue I have recently encountered. Due to its simplicity, I opened a PR directly without raising an issue first to avoid redundancy. Please, let me know if I should also raise an issue, and I'll do that right away.
Description
When output_attentions is True, sdpa implementation's forward method calls the eager implementation's forward method. However, a None mask is still returned if sdpa's 'AttentionMaskConverter._ignore_causal_mask_sdpa' returns true (which occurs whenever the input is unmasked, as sdpa would defer the causal masking to the sdpa Pytorch implementation).
This inconsistency causes the model to run the eager implementation with no causal attention mask if the original input is unmasked (e.g., if a single input sequence is encoded or all encoded input sequences have the same length) and requires_attn=True.
Pull Request section?
documentation guidelines, and
here are tips on formatting docstrings.
Tagging @ArthurZucker and @younesbelkada